{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import mnist_inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 定义神经网络的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE_BASE = 0.8\n",
    "LEARNING_RATE_DECAY = 0.99\n",
    "REGULARIZATION_RATE = 0.0001\n",
    "TRAINING_STEPS = 3000\n",
    "MOVING_AVERAGE_DECAY = 0.99"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义训练过程，并将计算图及运行时统计信息写入TensorBoard日志文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(mnist, log_dir=\"log/modified_mnist_train.log\", trace_interval=1000):\n",
    "    \"\"\"Train the MNIST model, logging the graph and runtime traces for TensorBoard.\n",
    "\n",
    "    The model is built in its own tf.Graph, so re-running this cell or calling\n",
    "    train() twice in one kernel does not pile duplicate ops/variables into the\n",
    "    default graph (and into the TensorBoard graph view).\n",
    "\n",
    "    Args:\n",
    "        mnist: dataset from input_data.read_data_sets(..., one_hot=True); only\n",
    "            mnist.train.num_examples and mnist.train.next_batch() are used.\n",
    "        log_dir: directory the tf.summary.FileWriter writes event files to.\n",
    "        trace_interval: every trace_interval steps (starting at step 0) one step\n",
    "            is run with FULL_TRACE, its RunMetadata is recorded and the\n",
    "            training-batch loss is printed. Must be a positive int.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if trace_interval is not positive.\n",
    "    \"\"\"\n",
    "    if trace_interval <= 0:\n",
    "        raise ValueError(\"trace_interval must be positive, got %r\" % (trace_interval,))\n",
    "\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        # Input placeholders, grouped under one name scope in TensorBoard.\n",
    "        with tf.name_scope('input'):\n",
    "            x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')\n",
    "            y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')\n",
    "        regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)\n",
    "        y = mnist_inference.inference(x, regularizer)\n",
    "        global_step = tf.Variable(0, trainable=False)\n",
    "\n",
    "        # Exponential moving average over all trainable variables.\n",
    "        with tf.name_scope(\"moving_average\"):\n",
    "            variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "            variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "\n",
    "        # Loss = mean cross-entropy + everything in the 'losses' collection\n",
    "        # (presumably the L2 terms mnist_inference.inference adds via regularizer;\n",
    "        # verify there -- tf.add_n fails if that collection is empty).\n",
    "        with tf.name_scope(\"loss_function\"):\n",
    "            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))\n",
    "            cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "            loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))\n",
    "\n",
    "        # Learning-rate schedule, optimizer and the combined per-step training op.\n",
    "        with tf.name_scope(\"train_step\"):\n",
    "            # decay_steps = batches per epoch; staircase=True decays once per epoch.\n",
    "            learning_rate = tf.train.exponential_decay(\n",
    "                LEARNING_RATE_BASE,\n",
    "                global_step,\n",
    "                mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,\n",
    "                staircase=True)\n",
    "\n",
    "            train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "\n",
    "            # train_op = one gradient step plus one moving-average update.\n",
    "            with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "                train_op = tf.no_op(name='train')\n",
    "\n",
    "        writer = tf.summary.FileWriter(log_dir, graph)\n",
    "        try:\n",
    "            with tf.Session(graph=graph) as sess:\n",
    "                tf.global_variables_initializer().run()\n",
    "                for i in range(TRAINING_STEPS):\n",
    "                    xs, ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "                    feed = {x: xs, y_: ys}\n",
    "\n",
    "                    if i % trace_interval == 0:\n",
    "                        # Trace this step in full and record the RunMetadata proto.\n",
    "                        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)\n",
    "                        run_metadata = tf.RunMetadata()\n",
    "                        _, loss_value, step = sess.run(\n",
    "                            [train_op, loss, global_step], feed_dict=feed,\n",
    "                            options=run_options, run_metadata=run_metadata)\n",
    "                        writer.add_run_metadata(run_metadata=run_metadata, tag=(\"tag%d\" % i), global_step=i)\n",
    "                        print(\"After %d training step(s), loss on training batch is %g.\" % (step, loss_value))\n",
    "                    else:\n",
    "                        # loss/global_step are only needed on traced steps.\n",
    "                        sess.run(train_op, feed_dict=feed)\n",
    "        finally:\n",
    "            # Flush and close the event file even if training raises.\n",
    "            writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 主函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../datasets/MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "After 1 training step(s), loss on training batch is 3.15454.\n",
      "After 1001 training step(s), loss on training batch is 0.203072.\n",
      "After 2001 training step(s), loss on training batch is 0.166225.\n"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    # argv is unused; presumably kept so main() fits the tf.app.run(main) convention.\n",
    "    # Load MNIST with one-hot label vectors, then run the training loop on it.\n",
    "    train(input_data.read_data_sets(\"../datasets/MNIST_data\", one_hot=True))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
